--- external/eg3d/training/volumetric_rendering/renderer.py	2023-03-14 23:32:13.458508000 +0000
+++ external_reference/eg3d/training/volumetric_rendering/renderer.py	2023-03-14 23:28:05.439774713 +0000
@@ -17,8 +17,9 @@
 import torch
 import torch.nn as nn
 
-from training.volumetric_rendering.ray_marcher import MipRayMarcher2
-from training.volumetric_rendering import math_utils
+from external.eg3d.training.volumetric_rendering.ray_marcher import MipRayMarcher2
+from external.eg3d.training.volumetric_rendering import math_utils
+from utils import noise_util
 
 def generate_planes():
     """
@@ -26,15 +27,15 @@
     plane. Should work with arbitrary number of planes and planes of
     arbitrary orientation.
     """
-    return torch.tensor([[[1, 0, 0],
+    return torch.tensor([[[1, 0, 0], # XY
                             [0, 1, 0],
                             [0, 0, 1]],
-                            [[1, 0, 0],
+                            [[1, 0, 0], # XZ
                             [0, 0, 1],
                             [0, 1, 0]],
-                            [[0, 0, 1],
-                            [1, 0, 0],
-                            [0, 1, 0]]], dtype=torch.float32)
+                            [[0, 1, 0], # YZ
+                            [0, 0, 1],
+                            [1, 0, 0]]], dtype=torch.float32)
 
 def project_onto_planes(planes, coordinates):
     """
@@ -50,16 +51,17 @@
     coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
     inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
     projections = torch.bmm(coordinates, inv_planes)
-    return projections[..., :2]
+    return projections[..., :2] # rows are batch-major: rows 0..n_planes-1 hold the planes of batch element 0
 
-def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
-    assert padding_mode == 'zeros'
-    N, n_planes, C, H, W = plane_features.shape
-    _, M, _ = coordinates.shape
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='reflection', box_warp=None):
+    assert padding_mode == 'reflection' 
+    N, n_planes, C, H, W = plane_features.shape # bs x n_planes x C x H x W
+    _, M, _ = coordinates.shape # bs x (render_h * render_w * samples) x 3
     plane_features = plane_features.view(N*n_planes, C, H, W)
 
     coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds
 
+    # shape = [bs*n_planes, 1, render_h * render_w * samples, 2]
     projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
     output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
     return output_features
@@ -85,7 +87,7 @@
         self.ray_marcher = MipRayMarcher2()
         self.plane_axes = generate_planes()
 
-    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
+    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, noise_input=None):
         self.plane_axes = self.plane_axes.to(ray_origins.device)
 
         if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
@@ -94,10 +96,10 @@
             if torch.any(is_ray_valid).item():
                 ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
                 ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
-            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'], rendering_options['sample_deterministic'])
         else:
             # Create stratified depth samples
-            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'], rendering_options['sample_deterministic'])
 
         batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape
 
@@ -105,19 +107,30 @@
         sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
         sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
 
+        # prevent rays from exceeding Y clip value
+        # removes sky blobs when extending far bound at inference
+        if rendering_options['y_clip'] is not None:
+            # limit the depth of the ray so that it does not surpass y_clip
+            y_clip = rendering_options['y_clip']
+            max_depth = (y_clip - ray_origins[..., 1]) / ray_directions[..., 1]
+            max_depth = max_depth[..., None, None] # (B, HW, 1, 1); broadcasts against depths_coarse (B, HW, num_samples, 1)
+            depths_clip = torch.where(max_depth > 0, torch.minimum(depths_coarse, max_depth), depths_coarse)
+            depths_coarse = depths_clip # replace coarse depths with clipped depth
+            sample_coordinates = (ray_origins.unsqueeze(-2) + depths_clip * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
 
         out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
         colors_coarse = out['rgb']
         densities_coarse = out['sigma']
         colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1])
         densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1)
+        noise_coarse = noise_util.sample_noise(noise_input, sample_coordinates).reshape(batch_size, num_rays, samples_per_ray, 1)
 
         # Fine Pass
         N_importance = rendering_options['depth_resolution_importance']
         if N_importance > 0:
-            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
+            _, _, _, weights, _ = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, noise_coarse)
 
-            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
+            depths_fine = self.sample_importance(depths_coarse, weights, N_importance, rendering_options['sample_deterministic'])
 
             sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3)
             sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
@@ -127,20 +140,19 @@
             densities_fine = out['sigma']
             colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1])
             densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1)
+            noise_fine = noise_util.sample_noise(noise_input, sample_coordinates).reshape(batch_size, num_rays, N_importance, 1)
 
-            all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
-                                                                  depths_fine, colors_fine, densities_fine)
+            all_depths, all_colors, all_densities, all_noise = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, noise_coarse, depths_fine, colors_fine, densities_fine, noise_fine)
 
             # Aggregate
-            rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
+            rgb_final, depth_final, disp_final, weights, noise_final = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, all_noise)
         else:
-            rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
-
+            rgb_final, depth_final, disp_final, weights, noise_final = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, noise_coarse)
 
-        return rgb_final, depth_final, weights.sum(2)
+        return rgb_final, depth_final, disp_final, weights.sum(2), noise_final
 
     def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
-        sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
+        sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='reflection', box_warp=options['box_warp'])
 
         out = decoder(sampled_features, sample_directions)
         if options.get('density_noise', 0) > 0:
@@ -154,19 +166,22 @@
         all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
         return all_depths, all_colors, all_densities
 
-    def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
+    def unify_samples(self, depths1, colors1, densities1, noise1,
+                      depths2, colors2, densities2, noise2):
         all_depths = torch.cat([depths1, depths2], dim = -2)
         all_colors = torch.cat([colors1, colors2], dim = -2)
         all_densities = torch.cat([densities1, densities2], dim = -2)
+        all_noise = torch.cat([noise1, noise2], dim = -2)
 
         _, indices = torch.sort(all_depths, dim=-2)
         all_depths = torch.gather(all_depths, -2, indices)
         all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
         all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+        all_noise = torch.gather(all_noise, -2, indices.expand(-1, -1, -1, 1))
 
-        return all_depths, all_colors, all_densities
+        return all_depths, all_colors, all_densities, all_noise
 
-    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
+    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False, sample_deterministic=False):
         """
         Return depths of approximately uniformly spaced samples along rays.
         """
@@ -177,21 +192,30 @@
                                     depth_resolution,
                                     device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
             depth_delta = 1/(depth_resolution - 1)
-            depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+            if not sample_deterministic:
+                depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+            else:
+                depths_coarse += 0.5 * depth_delta
             depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
         else:
             if type(ray_start) == torch.Tensor:
                 depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
                 depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
-                depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
+                if not sample_deterministic:
+                    depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
+                else:
+                    depths_coarse += 0.5 * depth_delta[..., None]
             else:
                 depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
                 depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
-                depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+                if not sample_deterministic:
+                    depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+                else:
+                    depths_coarse += 0.5 * depth_delta
 
         return depths_coarse
 
-    def sample_importance(self, z_vals, weights, N_importance):
+    def sample_importance(self, z_vals, weights, N_importance, sample_deterministic=False):
         """
         Return depths of importance sampled points along rays. See NeRF importance sampling for more.
         """
@@ -207,8 +231,7 @@
             weights = weights + 0.01
 
             z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
-            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
-                                             N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
+            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], N_importance, det=sample_deterministic).detach().reshape(batch_size, num_rays, N_importance, 1)
         return importance_z_vals
 
     def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
@@ -250,4 +273,4 @@
                              # anyway, therefore any value for it is fine (set to 1 here)
 
         samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
-        return samples
\ No newline at end of file
+        return samples
